
from typing import Callable
import gdb
import argparse
import pstats
import time
import cProfile
import pwndbg
import pwndbg.commands
import pwndbg.commands.context
import pwndbg.lib.cache
import pwndbg.aglib.regs
import pwndbg.aglib.vmmap
import pwndbg.commands.telescope

# Default number of times a benchmark callback is invoked per run.
COUNT = 100


def run_benchmark(name: str, prefix: str, callback: Callable[[], object], count: int = COUNT) -> float:
    """
    Profile `count` back-to-back calls of `callback` with cProfile.

    The collected statistics are written to
    ``{prefix}-{count}-{name}-{unix timestamp}.pstats`` in the current working
    directory, and a one-line summary is printed.

    Args:
        name: Label placed into the output pstats filename.
        prefix: Benchmark identifier placed at the start of the filename.
        callback: Zero-argument callable to profile; its return value is ignored.
        count: Number of times to invoke `callback`. Must be positive.

    Return:
        Average time to execute callback in seconds

    Raises:
        ValueError: If `count` is not positive. Without this check the average
            would be a division by zero (or by a negative number).
    """
    if count <= 0:
        raise ValueError(f"count must be a positive integer, got {count!r}")

    profiler = cProfile.Profile()

    profiler.enable()
    try:
        for _ in range(count):
            callback()
    finally:
        # Always switch profiling off again, even when the callback raises,
        # so a failed benchmark does not leave the profiler enabled.
        profiler.disable()

    pstats_file_name_base = f"{prefix}-{count}-{name}-{int(time.time())}"
    profiler.dump_stats(f"{pstats_file_name_base}.pstats")

    # total_tt is the total time summed over all profiled functions.
    full_time = pstats.Stats(profiler).total_tt
    average_time = full_time / count

    print(f"Time elapsed: {full_time}. Average time: {average_time}")

    return average_time

# Argument parser for the `benchmark_context` command registered right below.
parser = argparse.ArgumentParser(description="""
Benchmark contexts.

Uses cProfile to instrument the code.

Experimentally, this can cause the code to run ~2.5 times slower than it's real speed.
However, you can use it to find relative performance before/after changes, and find functions that take longer than they should.
""")

# Single positional argument: a free-form label used in the output filenames.
parser.add_argument(
    "name",
    help="Name placed into output pstats filename",
    type=str,
)

@pwndbg.commands.Command(parser, category=pwndbg.commands.CommandCategory.DEV)
def benchmark_context(name: str) -> None:
    """
    Profile the `context` command and several of its individual sections.

    Six benchmarks are run back to back through `run_benchmark`, using the
    prefixes "context-without-cache", "context-with-cache", "step", "regs",
    "disasm" and "stack".

    Args:
        name: Label placed into every output pstats filename.
    """

    # Context - clear cache
    ## Benchmark the `context` command, no caching

    def run_with_clear():
        # Clear *before* rendering so that every iteration, including the very
        # first one, starts with cold caches.
        pwndbg.lib.cache.clear_caches()
        # Bypass the decorator so cProfile/snakeviz sees things correctly
        pwndbg.commands.context.context.function()

    run_benchmark(name, "context-without-cache", run_with_clear)

    # Drop whatever the last uncached iteration left behind.
    pwndbg.lib.cache.clear_caches()

    # Context - with cache
    ## Benchmark the `context` command with cache
    run_benchmark(
        name,
        "context-with-cache",
        # Bypass the decorator so cProfile/snakeviz sees things correctly
        lambda: pwndbg.commands.context.context.function(),
    )

    pwndbg.lib.cache.clear_caches()

    # Step
    ## Benchmark a single-instruction step followed by a full `context` render

    def run_with_step():
        gdb.execute("stepi")
        # Bypass the decorator so cProfile can see function stack correctly
        pwndbg.commands.context.context.function()

    run_benchmark(name, "step", run_with_step)

    # The remaining benchmarks each step one instruction and then build a
    # single context section (presumably the step is there so each iteration
    # sees fresh program state - TODO confirm).

    # Regs

    def regs():
        gdb.execute("stepi")
        pwndbg.commands.context.context_regs()

    run_benchmark(name, "regs", regs)

    # Disasm

    def disasm():
        gdb.execute("stepi")
        pwndbg.commands.context.context_disasm()

    run_benchmark(name, "disasm", disasm)

    # Stack

    def stack():
        gdb.execute("stepi")
        pwndbg.commands.context.context_stack()

    run_benchmark(name, "stack", stack)



# Argument parser for the `benchmark_large_telescope` command registered below.
# Rebinding `parser` here is fine: the previous parser object was already
# handed to the decorator of the earlier command.
parser = argparse.ArgumentParser(description="""
Benchmark telescoping the entire stack
    """)

# Single positional argument: a free-form label used in the output filename.
parser.add_argument(
    "name",
    help="Name placed into output pstats filename",
    type=str,
)


@pwndbg.commands.Command(parser, category=pwndbg.commands.CommandCategory.DEV)
def benchmark_large_telescope(name: str) -> None:
    """
    Profile telescoping across the whole memory page that holds the stack.

    The page is looked up from the value of the register named by
    `pwndbg.aglib.regs.stack`; the telescope starts at the page start and
    covers one entry per pointer-sized slot in the page.

    Args:
        name: Label placed into the output pstats filename.
    """
    # Telescope entire stack
    stack_address = pwndbg.aglib.regs.read_reg(pwndbg.aglib.regs.stack)
    stack_page = pwndbg.aglib.vmmap.find(stack_address)
    start = stack_page.start
    # Computed once up front: the value does not change between iterations,
    # and keeping it out of the callback means only the telescope is profiled.
    slot_count = stack_page.memsz // pwndbg.aglib.arch.ptrsize

    def print_all_stack():
        pwndbg.commands.telescope.telescope(start, slot_count)

    # Only 4 iterations instead of the default (presumably because dumping a
    # whole page is slow - TODO confirm).
    run_benchmark(
        name,
        "all-stack",
        print_all_stack,
        4,
    )

